--- external/stylegan/metrics/metric_utils.py	2023-04-06 03:52:40.956867999 +0000
+++ external_reference/stylegan/metrics/metric_utils.py	2023-04-06 03:41:03.570631912 +0000
@@ -18,10 +18,14 @@
 import torch
 import dnnlib
 
+from tqdm import tqdm
+from utils import camera_util
+
 #----------------------------------------------------------------------------
 
 class MetricOptions:
-    def __init__(self, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None, progress=None, cache=True):
+    def __init__(self, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1,
+                 rank=0, device=None, progress=None, cache=True, training_mode=None):
         assert 0 <= rank < num_gpus
         self.G              = G
         self.G_kwargs       = dnnlib.EasyDict(G_kwargs)
@@ -31,6 +35,7 @@
         self.device         = device if device is not None else torch.device('cuda', rank)
         self.progress       = progress.sub() if progress is not None and rank == 0 else ProgressMonitor()
         self.cache          = cache
+        self.training_mode = training_mode
 
 #----------------------------------------------------------------------------
 
@@ -228,7 +233,16 @@
 
     # Main loop.
     item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)]
-    for images, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs):
+    for image_info, _labels in tqdm(torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs)):
+        if opts.training_mode == 'layout':
+            images = image_info['rgb'] # masked image
+        elif opts.training_mode == 'upsampler':
+            images = image_info['rgb'] # masked image
+        elif opts.training_mode == 'sky':
+            images = image_info['orig'] # full image
+        elif opts.training_mode == 'triplane':
+            images = image_info[:, :3] # first 3 channels are RGB
+
         if images.shape[1] == 1:
             images = images.repeat([1, 3, 1, 1])
         features = detector(images.to(opts.device), **detector_kwargs)
@@ -260,12 +274,41 @@
     progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
     detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
 
+    # Build a loader over the real dataset; only the 'sky' mode below reads from it.
+    dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
+    data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
+    from torch_utils import misc
+    random_seed = opts.dataset_kwargs.random_seed
+    sampler = misc.InfiniteSampler(dataset=dataset, rank=opts.rank, num_replicas=opts.num_gpus, seed=random_seed)
+    loader = torch.utils.data.DataLoader(dataset=dataset, sampler=sampler, batch_size=batch_gen, **data_loader_kwargs)
+    iterator = iter(loader)
+
     # Main loop.
     while not stats.is_full():
         images = []
         for _i in range(batch_size // batch_gen):
             z = torch.randn([batch_gen, G.z_dim], device=opts.device)
-            img = G(z=z, c=next(c_iter), **opts.G_kwargs)
+            # Each training mode passes different inputs to G and unpacks a different output.
+            if opts.training_mode == 'layout':
+                z = torch.randn([batch_gen, G.z_dim], device=opts.device)
+                camera_params = camera_util.get_full_image_parameters(
+                    G, G.layout_decoder_kwargs.nerf_out_res,
+                    batch_size=batch_gen, device=z.device, Rt=None)
+                img, _infos = G(z=z, c=next(c_iter), camera_params=camera_params, **opts.G_kwargs)
+            elif opts.training_mode == 'upsampler':
+                img, thumb = G(z=z, c=next(c_iter), **opts.G_kwargs)
+                img = img[:, :3] # take rgb channels only
+            elif opts.training_mode == 'sky':
+                image_info, _ = next(iterator)
+                real_img_masked = image_info['rgb']
+                real_acc = image_info['acc']
+                real_img_masked = (real_img_masked / 127.5) - 1
+                real_img_masked = real_img_masked * real_acc
+                z = torch.randn([batch_gen, G.z_dim], device=opts.device)
+                img = G(z=z, c=next(c_iter), img=real_img_masked.to(opts.device), acc=real_acc.to(opts.device), **opts.G_kwargs)
+            elif opts.training_mode == 'triplane':
+                img = G(z=z, c=next(c_iter), **opts.G_kwargs)['image']
+                img = img[:, :3] # take rgb channels only
             img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
             images.append(img)
         images = torch.cat(images)
